{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Text Data Explanation Benchmarking: Emotion Multiclass Classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook demonstrates how to use the benchmark utility to benchmark the performance of an explainer for text data. In this demo, we showcase explanation performance for partition explainer on an Emotion Multiclass Classification model. The metrics used to evaluate are \"keep positive\" and \"keep negative\". The masker used is Text Masker.\n", "\n", "The new benchmark utility uses the new API with MaskedModel as wrapper around user-imported model and evaluates masked values of inputs." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import nlp\n", "import numpy as np\n", "import pandas as pd\n", "import scipy as sp\n", "import torch\n", "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n", "\n", "import shap\n", "import shap.benchmark as benchmark\n", "\n", "pd.set_option(\"display.max_columns\", None)\n", "pd.set_option(\"display.max_rows\", None)\n", "pd.set_option(\"max_colwidth\", None)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load Data and Model" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration default\n" ] } ], "source": [ "train, test = nlp.load_dataset(\"emotion\", split=[\"train\", \"test\"])\n", "\n", "data = {\"text\": train[\"text\"], \"emotion\": train[\"label\"]}\n", "\n", "data = pd.DataFrame(data)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"nateraw/bert-base-uncased-emotion\", use_fast=True)\n", "model = AutoModelForSequenceClassification.from_pretrained(\"nateraw/bert-base-uncased-emotion\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Class Label Mapping" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# set mapping between label and id\n", "id2label = model.config.id2label\n", "label2id = model.config.label2id\n", "labels = sorted(label2id, key=label2id.get)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define Score Function" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def f(x):\n", " tv = torch.tensor([tokenizer.encode(v, padding=\"max_length\", max_length=128, truncation=True) for v in x])\n", " attention_mask = (tv != 0).type(torch.int64)\n", " outputs = model(tv, attention_mask=attention_mask)[0].detach().numpy()\n", " scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T\n", " val = sp.special.logit(scores)\n", " return val" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create Explainer Object" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "explainers.Partition is still in an alpha state, so use with caution...\n" ] } ], "source": [ "explainer = shap.Explainer(f, tokenizer, output_names=labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run SHAP Explanation" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\r", "Partition explainer: 5%|█▋ | 1/20 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sequential_perturbation = benchmark.perturbation.SequentialPerturbation(\n", " explainer.model, explainer.masker, sort_order, perturbation\n", ")\n", "xs, ys, auc = sequential_perturbation.model_score(shap_values, data[\"text\"][0:20])\n", "sequential_perturbation.plot(xs, ys, auc)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "sort_order = \"negative\"\n", "perturbation = \"keep\"" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEGCAYAAABsLkJ6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAq7klEQVR4nO3deXxU5b3H8c9vsu8bCUtCNvYdAQXctV7rVmnVeq22dWu1u1dv71Vr2+ut1l5vtdZqb1u3WmvFulCklqpYiyuCgILIvgQIW0KALGRPnvvHDBggywBJTmbm+369ziszZ87MfE8C85vzPOc8jznnEBGRyOPzOoCIiHhDBUBEJEKpAIiIRCgVABGRCKUCICISoaK9DnA0+vXr5woLC72OISISUpYsWbLbOZd9+PqQKgCFhYUsXrzY6xgiIiHFzDa3t15NQCIiEUoFQEQkQqkAiIhEqJDqAxCR0NDU1ERpaSn19fVeR4ko8fHx5OXlERMTE9T2KgAi0u1KS0tJSUmhsLAQM/M6TkRwzlFRUUFpaSlFRUVBPUdNQCLS7err68nKytKHfy8yM7Kyso7qqEsFQER6hD78e9/R/s4jogC8u343v5m/wesYIiJ9SkQUgDfXlnPfa2vYvq/O6ygi0otmz56NmbF69eqD6+bPn89FF110yHbXXHMNL7zwAuDvwL7tttsYNmwYkyZNYvr06fz9738/4rWvv/56JkyYwPjx47nsssuoqanpMMeWLVtITk7mvvvuO7juwQcfZOzYsYwZM4Zf/vKXRzzn/vvvx8zYvXs3AHv37uULX/gC48eP56STTmLFihVH9btoT0QUgK9OL8A5x1ML2r0YTkTC1MyZMzn11FOZOXNm0M/50Y9+xI4dO1ixYgVLly5l9uzZVFdXH7HdAw88wLJly1i+fDn5+fk8/PDDHb7mLbfcwvnnn3/w/ooVK3j00UdZtGgRy5Yt4+WXX2b9+vUHH9+6dSuvvfYa+fn5B9fdc889TJw4keXLl/PUU09x0003Bb1PHYmIApCXkchnxwxg5qIt1DY2ex1HRHpBTU0N77zzDo8//jjPPvtsUM+pra3l0Ucf5aGHHiIuLg6A/v37c/nllx+xbWpqKuA/+6aurq7D9vfZs2dTVFTEmDFjDq5btWoVU6dOJTExkejoaM444wxmzZp18PGbb76Z//3f/z3kNVeuXMnZZ58NwMiRIykpKWHXrl1B7VdHIuY00OtOLeLvK3bylw+3cdXUAq/jiESM//7rJ6zcXtWtrzl6UCr/9bkxnW7z0ksvcd555zF8+HCysrJYsmQJkydP7vQ569evJz8//+CHe1euvfZa5s6dy+jRo7n//vuPeLympoZ7772XefPmHdL8M3bsWO644w4qKipISEhg7ty5TJky5WDu3NxcJkyYcMhrTZgwgVmzZnHaaaexaNEiNm/eTGlpKf379w8qa3si4ggAYEpBBuNy03jinU20tmoeZJFwN3PmTK644goArrjiioPNQB19Uz+Ws5Z+//vfs337dkaNGsWf//znIx6/8847ufnmm0lOTj5k/ahRo7j11ls599xzOe+885g4cSJRUVHU1tZyzz338JOf/OSI17rtttvYt28fEydO5KGHHuKEE04gKirqqDMfwjkXMsvkyZPd8XhxyVZXcOvLbv6asuN6HRHp3MqVKz19/4qKCpeQkODy8/NdQUGBy8vLc4MHD3atra3u448/dieffPIh23/uc59z8+fPd/v373eZmZmusrLyqN7vzTffdBdeeOER60899VRXUFDgCgoKXFpamsvIyHAPPfTQEdvdfvvt7te//rVbvny5y87OPvicqKgoN3jwYLdjx45Dtm9tbXUFBQXt5mzvdw8sdu18pkbMEQDAheMHkp0SxxPvbPI6ioj0oBdeeIGvfOUrbN68mZKSErZu3UpRURFvv/02w4YNY/v27axatQqAzZs3s2zZMiZOnEhiYiLXX389N910E42NjQCUl5fz/PPPH/L6zrmDnbbOOebMmcPIkSOPyPH2229TUlJCSUkJ//Zv/8YPfvADvvOd7wBQVlYG+M8QmjVrFldeeSXjxo2jrKzs4HPy8vJYunQpAwYMYN++fQczPfbYY5x++ulBN1V1JGL6AADioqP4yrQCfjFvLevLahiak9z1k0Qk5MycOZNbb731kHWXXnopM2fO5PTTT+fpp5/m2muvpb6+npiYGB577DHS0tIAuPvuu/nhD3/I6NGjiY+PJykp6YgmGeccV199NVVVVTjnmDBhAr/5zW8AmDNnDosXL263GefwPBUVFcTExPDrX/+a9PT0TrdftWoVV199NWbGmDFjePzxx4/yt3Ik8x8dhIYpU6a4450QZndNAyf/7A0uPzGPuz8/rpuSiUhbq1atYtSoUV7HiEjt/e7NbIlzbsrh20ZUExBAv+Q4ZkwcxItLtlFZ2+R1HBERz0RcAQC49pQi6ppaePaDLV5HERHxTEQWgNGDUplWnMlTCzbT3NLqdRyRsBRKzcvh4mh/554UADP7uZmtNrPlZvYXM0vv7QzXnlLEtn11vLby+K6kE5EjxcfHU1FRoSLQi1xgPoD4+Pign+PVWUDzgNudc81mdi9wO3BrF8/pVueM6s/gzAR+/+4mLhg3sDffWiTs5eXlUVpaSnl5uddRIsqBGcGC5UkBcM691ubu+8BlvZ0hymdcPb2Qu/+2ihXbKhmbm9bbEUTCVkxMTNCzUol3+kIfwHXAkWOtBpjZDWa22MwWd/e3iS9OGUxCTBR/eK+kW19XRCQU9FgBMLPXzWxFO8uMNtvcATQDf+rodZxzjzjnpjjnpmRnZ3drxrSEGC6ZlMtLy7azZ39jt762iEhf12MFwDl3jnNubDvLSwBmdg1wEXCV87Cn6OqTC2lsbuXPH2z1KoKIiCe8OgvoPOA/gYudc7VeZDhgeP8Uphdn8fT7OiVURCKLV30ADwMpwDwz+8jMfutRDsB/FLBtXx2vryrzMoaISK/y6iygoV68b0fOGZVDbnoCTy0o4byxA7yOIyLSK/rCWUCei47ycdW0fN7bUMG6XUfO/SkiEo5UAAL+dcpgYqKMZxZpfCARiQwqAAFZyXGcN3YgLy4ppb6pxes4IiI9TgWgjStPyqeqvpmXl+/wOoqISI9TAWhjWnEmxdlJPLNws9dRRER6nApAG2bGlSfls3TLPlbtqPI6johIj1IBOMxlk/OIjfbxzEJ1BotIeFMBOEx6YiwXjRvIXz7cxv6GZq/jiIj0GBWAdlx+4mBqGpp5c63GMheR8KUC0I4pBRmkxEfz5hoVABEJXyoA7YiO8nHq0H68ubZcU9qJSNhSAejAmSOy2VlVz9pdNV5HERHpESoAHTh9uH/ymflrNEKoiIQnFYAODExLYET/FHUEi0jYUgHoxJkjsvmgZI9OBxWRsKQC0IkzhmfT1OJ4b0OF11FERLqdCkAnphRmkhgbxZtr1Q8gIuFHBaATsdE+Th7Sj/lrdDqoiIQfFYAunDEim9K9dWzcvd/rKCIi3UoFoAtnDPOfDvru+t0eJxER6V4qAF0YnJlAbnoCC9QRLCJhRgWgC2bG9CFZLNhYQWur+gFEJHyoAARhenEW+2qbWL2z2usoIiLdRgUgCNOHZAGwYKOagUQkfKgABGFQegKFWYks2KCOYBEJHyoAQZo+JIuFm/bQon4AEQkTKgBBmlacRXV9M59sr/Q6iohIt1ABCNKBfgCNCyQi4UIFIEg5KfEMzUnW9QAiEjZUAI7C9OIsPijZQ1NLq9dRRESOmwrAUTh5SBa1jS0sL93ndRQRkeOmAnAUphb7+wHe37jH4yQiIsdPBeAoZCbFMnJACu/rgjARCQMqAEdpalEmi0v2qh9AREKeCsBRmlacRV1TC8tLdT2AiIQ2FYCjdFJRJgALN6kZSERCmwrAUcpKjmN4/2R1BItIyPOkAJjZXWa23Mw+MrPXzGyQFzmO1dSiLJboegARCXFeHQH83Dk33jk3EXgZ+LFHOY7JtOIs9je2sGKb+gFEJHR5UgCcc1Vt7iYBITXE5tRifz+AmoFEJJR51gdgZj81s63AVXRyBGBmN5jZYjNbXF5e3nsBO9EvOY6hOcnqCBaRkNZjBcDMXjezFe0sMwCcc3c45wYDfwK+09HrOOcecc5Ncc5Nyc7O7qm4R21acSYfbNpDs/oBRCRE9VgBcM6d45wb287y0mGb/gm4tKdy9JSpRYF+gO1VXW8sItIHeXUW0LA2d2cAq73IcTwO9AO8p2kiRSREedUH8D+B5qDlwLnATR7lOGY5KfGMHpjKP1eXeR1FROSYRHvxps65kGvyac/ZI3P4v/nr2VfbSHpirNdxRESOSpdHAGb2xWDWRaKzRubQ6uDNtX3j7CQRkaMRTBPQ7UGuizgTB6eTmRSrZiARCUkdNgGZ2fnABUCumf2qzUOpQHNPBwsFUT7jjOHZzF9TRkurI8pnXkcSEQlaZ0cA24HFQD2wpM0yB/hsz0cLDWeNzGFvbRMfbd3rdRQRkaPS4RGAc24ZsMzM/uSc0zf+DpwxLJson/HG6jImF2R6HUdEJGjB9AGsM7ONhy89nixEpCXGMDk/gzdWqyNYREJLMKeBTmlzOx74IqCvum2cNTKHe19ZzY7KOgamJXgdR0QkKF0eATjnKtos25xzvwQu7PlooePskTkAvKGzgUQkhHR5BGBmk9rc9eE/IvDkArK+anj/ZIqzk3hhSSlXTS3wOo6ISFCC+SC/v83tZqAEuLxH0oQoM+OqqQXc9fJKVmyrZGxumteRRES6FEwT0Fltln9xzn3dObemN8KFkssm5REf4+NPCzd7HUVEJCjBDAWRZWa/MrOlZrbEzB40s6zeCBdK0hJjuHjCIGZ/uJ2q+iav44iIdCmY00CfBcrxj9l/WeD2n3syVKj68rQC6ppa+MvSbV5HERHpUjAFYKBz7i7n3KbAcjfQv6eDhaLxeemMz0vj6fc341xITXMsIhEomALwmpldYWa+wHI58GpPBwtVX55awLqyGhZt0oTxItK3BVMAvg48AzQElmeBG82s2sw0H+JhPjdhEKnx0fxhQYnXUUREOhXMWUApzjmfcy4msPgC61Kcc6m9ETKUJMRGceXUAl5ZsZOte2q9jiMi0qFgzgL6RzDr5FPXnFyIz4zH39nkdRQRkQ51WADMLN7MMoF+ZpZhZpmBpRDI7bWEIWhAWjwXTxjEc4u3UlmrU0JFpG/q7AjgRvzj/48ElvLpfAAvAQ/3fLTQ9rXTiqltbOGZRVu8jiIi0q4OC4Bz7kHnXBHwfedcUZtlgnNOBaALowelcurQfjz53iYam1u9jiMicoRgzgKqNLOvHr70eLIw8LXTithV1cBfl233OoqIyBGCKQAntllOA+4ELu7BTGHjjOHZDO+fzBPvbtKFYSLS53Q5Gqhz7rtt75tZOv5rAaQLZsZXpxfyw9krWLplH5MLMryOJCJyUDBHAIfbDxR1d5Bw9YUTckmJi+YpXRgmIn1MMNcB/NXM5gSWl4E1wF96Plp4SIqL5rIpecz9eAdl1fVexxEROSiYCWHua3O7GdjsnCvtoTxh6SvTCvj9uyU8u2gr3/vMMK/jiIgAwTUBbQFSAssOffgfveLsZE4fns2fFm6mqUWnhIpI39DZlcCpZvYc8DpwXWB53cyeNzONAXSUrp5ewK6qBl77ZJfXUUREgM6PAH4FrASGOecucc5dAgwBPkZXAh+1M0fkMDgzgcfe2ahTQkWkT+isAJzinLvTOXewzcL5/QSY3vPRwkuUz7jx9CF8uGUfb64t9zqOiMgxnQYKYN2aIkJcPmUwuekJ/GLeWh0FiIjnOisA75nZj83skA97M/sRsKBnY4Wn2Ggf3/vMUJaXVvL6qjKv44hIhOusAHwXGAesN7MXA8sGYALwnV5JF4YumZRHQVYiv5i3ltZWHQWIiHc6Gw20yjn3ReBc4MnAcq5z7jLnXGXvxAs/MVE+vnf2MFbtqOLVT3Z6HUdEIlgwU0JucM79NbBs6I1Q4e7zJ+RSnJ3E/fPW0qzrAkTEI8faCSzHIcpn/OdnR7K+rIaZmjBGRDziaQEws383M2dm/bzM4YXPjunP1KJMfjFvLZV1mjZSRHpfZ1cCZ3a2HO8bm9lg/P0LEfkV2Mz40UWj2VfXxMNvrPM6johEoM4Gg1sCONo/598Bxcf53g8A/4l/juGINDY3jS9OzuPJ90q4amoBhf2SvI4kIhGks7OAipxzxYfNB3xgOa4PfzObAWxzzi0LYtsbzGyxmS0uLw+/K2i/f+4IYqN83PXySl0cJiK9Kpj5AMzMvhy4AAwzyzezk4J43utmtqKdZQbwA+DHwQR0zj3inJvinJuSnZ0dzFNCSk5qPLecO4J/rC7jl6+rKUhEek8w8wH8H9AKnA3cBVQDL+KfI7hDzrlz2ltvZuPwzyi2LHCRcR6w1MxOcs5F5Inx151SyOodVTz4j3UUZCVyyaQ8ryOJSAQIpgBMdc5NMrMPAZxze80s9ljf0Dn3MZBz4L6ZlQBTnHO7j/U1Q52Z8dMvjKN0bx23vric3PQEphZneR1LRMJcMKeBNplZFP6OX8wsG/8RgXSj2Ggfv/3yZPIzE/n6U4tZXrrP60giEuaCKQC/wj8HcI6Z/RR4B7inuwI45woj+dt/W2mJMfzhupNIS4zhqscW8uGWvV5HEpEwFsxQEH/Cf7rmz4AdwOedc8/3dLBIlZeRyLM3TCcjMZavPr6IJZtVBESkZwR1IRhQBswEngF2dceFYNKx3PQE/nzjNLKSY7nmiUVs3VPrdSQRCUOdHQEsARYHfpYDa4F1gdtLej5aZBuYlsAfr59Kq3PcNmu5rhEQkW7X5YVg+CeF/5xzrp9zLgu4CHittwJGssGZidx+wSjeXV/BzEVbvY4jImEmmE7gac65uQfuOOf+Dpzcc5GkrStPyufkIVncM3cV2/bVeR1HRMJIMAVgu5n90MwKA8sdwPaeDiZ+Pp9x76Xj/U1BL6opSES6TzAF4EtANv5TQf+C/yKuL/VkKDnU4MxEbjt/JG+v283fPt7hdRwRCRNdXgnsnNsD3GRmKf67rqbnY8nhrppawMxFW/nZ3NWcM6o/8TFRXkcSkRAXzGBw4wLDQKwAPjGzJWY2tuejSVtRPuO/PjeabfvqeOStjV7HEZEwEEwT0O+AW5xzBc65AuDfgUd6Npa0Z1pxFheMG8D/zV/PdnUIi8hxCqYAJDnn/nngjnNuPqCZSzxy+/mjaHVw7yurvY4iIiEumAKw0cx+1OYsoB8CaoPwyODMRG48vZiXPtrOwo0VXscRkRAWTAG4Dv9ZQLMCS3ZgnXjkW2cOZXBmArfP+pj6phav44hIiApmMLi9zrnvOecmBZabnHMaocxDCbFR3POFcWzcvZ+H31jvdRwRCVEdngZqZnM6e6Jz7uLujyPBOm1YNpdMyuW3b27gwvEDGTUw1etIIhJiOrsOYDqwFf8ooAsB65VEErQfXTiaN9eUc9usj5n1zZOJ8ulPJCLB66wJaAD+ydvHAg8C/wLsds696Zx7szfCSecykmL58edGs2zrPmYu2uJ1HBEJMZ2NBtrinHvFOXc1MA1YD8w3s+/0Wjrp0sUTBnFSUSa/mLeWqvomr+OISAjptBPYzOLM7BLgaeDbfDo9pPQRZsaPLxrN3tpGHvrHOq/jiEgI6WxGsKeABcAk4L+dcyc65+5yzm3rtXQSlLG5aXxxch5PvlfCpt37vY4jIiGisyOALwPDgJuA98ysKrBUm1lV78STYH3/syOIjfJxz9xVXkcRkRDRWR+AzzmXElhS2ywpzjmdc9jH5KTE8+2zhzJv5S7eWbfb6zgiEgKCuRJYQsR1pxSRn5nInX/9hKaWVq/jiEgfpwIQRuJjovjRRaNZX1bDUws2ex1HRPo4FYAwc86oHE4fns0v562lvLrB6zgi0oepAIQZM//EMfXNLfyvhowWkU6oAIShIdnJXHdKEc8vKWXJ5j1exxGRPkoFIEx99zPDyE1P4KZnP6KyVlcIi8iRVADCVHJcNA9feQI7K+v5jxeW4ZzzOpKI9DEqAGHshPwMbjt/JK+t3MWT75V4HUdE+hgVgDB3/alFnDMqh3vmruKjrfu8jiMifYgKQJgzM+774gT6p8bztT8sZuueWq8jiUgfoQIQAdITY3ny2hNpbG7h2ic/UKewiAAqABFjaE4Kj3x1Clsqarnx6cU0NGsyeZFIpwIQQaYVZ/HzL47n/Y17+PfnltHSqjODRCJZZ3MCSxiaMTGXsqoGfjp3Fclx0fzsknGYaS5hkUikAhCBvn56MVX1TTz0xnqS46K548JRKgIiEciTAmBmdwJfB8oDq37gnJvrRZZIdcu/DKe6vpnH3tlElM+49byR+HwqAiKRxMsjgAecc/d5+P4R7cBcws2trfzurY1sKK/hgX+dSEp8jNfRRKSXqBM4gvl8xl0zxvKTGWP455pyLvm/91hfVu11LBHpJV4WgO+Y2XIze8LMMjrayMxuMLPFZra4vLy8o83kGJkZX51eyB+vO4nymgbO+cVbXP3EIl5fuUtnCYmEOeupQcLM7HVgQDsP3QG8D+wGHHAXMNA5d11XrzllyhS3ePHibs0pnyqrrueZhVuYuWgLu6oa6JccyxnDc/jMqBxOGdKPtEQ1D4mEIjNb4pybcsR6r0eJNLNC4GXn3NiutlUB6B1NLa38Y9Uu5n68k/lryqiqbwZgcGYCYwamMTE/nTNHZDOif4rOHhIJAX2qAJjZQOfcjsDtm4GpzrkrunqeCkDva25pZemWfSzevIdPtlXxyfZKSir84wnlpifwL6P7c/2pRQzOTPQ4qYh0pK8VgD8CE/E3AZUANx4oCJ1RAegbdlbW8881Zbyxuow315TT4hwzJg7iW2cOZWhOstfxROQwfaoAHCsVgL5nZ2U9j7y1kWcWbaahuZVzR/fnG2cM4YT8Dvv1RaSXqQBIj6qoaeDJ90p4asFmKuuamFqUyQ2nF3PWiBxdYCbiMRUA6RU1Dc08u2gLj7+ziR2V9RRnJ3H9qUV8fmIuSXEaeUTECyoA0quaWlqZ+/EOHn17Iyu2VZEUG8XFEwfxpZPyGZ+X7nU8kYiiAiCecM6xZPNenv1gKy8v3059UytTCjK48YwhfGakmodEeoMKgHiuqr6JF5eU8tjbm9i2r44h2UlcObWAGRMH0S85zut4ImFLBUD6jOaWVv728Q4ef2cTy0srifYZZ47I4dJJuZw9Koe46CivI4qElY4KgHrlpNdFR/mYMTGXGRNzWburmheXlDLrw228vmoXqfHRXDh+IFecmM+EweleRxUJazoCkD6huaWVdzdUMPvDbbyyYid1TS1MGJzO1dMLuHD8QB0ViBwHNQFJyKiub2LW0m38YUEJG8v3Mygtnu99ZhiXTs4jJkojmIscLRUACTnOOd5at5sH5q3lo637KMxK5D8+O5ILxg3QIHQiR6GjAqCvU9JnmRlnDM/mL986mce+OoX4mCi+/cxSrnpsIet2aeIakeOlAiB9nplxzuj+/O17p3HX58fyyfYqzn/wbX42dxV1jS1exxMJWSoAEjKifMZXphXwz++fyWWT8/jdWxs578G3eG/Dbq+jiYQkFQAJOZlJsfzPpeOZ+fVpGHDlowu55bmP2Lqn1utoIiFFBUBC1vQhWbzyb6fzjTOG8PLyHZx133xue3E5m3bv9zqaSEjQWUASFnZW1vOb+euZuWgrjS2t5KYnMK04i6lFmUwuzKC4X5LOHJKIpdNAJSLsqKzj1RU7eX/jHhZuqmBvbRMAGYkxTMrPYEphJicWZjAuL00Xl0nEUAGQiNPa6thQXsPSLXtZsnkvizfvZWO5v3koLtrHpPwMphVnMa04kwmD04mPUUGQ8KQCIIJ/5rLFm/eyaNMe3t9YwcodVTgHsdE+JuWnM604i1OG9mPi4HRddSxhQwVApB2VtU0s3FTBwk3+JqNPtvsLQlJsFFOLszhzRDZnj8whLyPR66gix0wFQCQIlbVNLNhYwbvrd/PWunI2V/hPLR05IIULxg1kxsRBFGQleZxS5OioAIgcgw3lNbyxqox5K3exqGQPACfkp/P5iblcNH4gWZrIRkKACoDIcdq+r445y7Yz+8NtrN5ZTZTPP1bReWMHcObwbHJS472OKNIuFQCRbrR6ZxWzP9zOnI+2sb2yHoAxg1I5eUgWk/IzOCE/gwFpKgjSN6gAiPQA5xyrdlTzzzVlzF9TxrKtlTS2tALQPzWOcbnpjMtNY8SAZHLTE8nNSCAjMUYXpUmvUgEQ6QUNzS2s2lHN0s17+XhbJctL97Fx937a/jeLjfKRmhBNanwMWcmxFGYlUZSdRGFWEoMzEhmcmUBagoqEdB/NCSzSC+Kio5g4OJ2JbeYzrmlopmT3frbtq2Pb3jrKqhuorm+isq6JsuoG3lxbzvNLSg95neS4aPqnxjEgLZ7+KfGkJ8aSkRhDemIMCbHRJMREER/jIz4mithoH3HRPmKjfcRGBX5G+4iL+vQxn0/FRI6kAiDSw5Ljohmbm8bY3LQOt6lpaGZzxX627qmjdG8tpXvr2FlZz67qehZu2sO+2kb2H8fcB7HRPhJiooiJ8hHlg2ifjyifEe0zfIGfsdE+YqJ8JMZGkZbgLzYZibFkp8SRkxJHZlIcSXFRJMdFkxQXTXJcNHHRPh2phDAVAJE+IDkumjGD0hgzqOMi0djcSmVdE/VNLdQ1tVDb2EJjcyuNza00NAdut7TS0NRKQ0vrwcfqm1qob26hvrGFplZHS4ujudXR6vw/m1taD/5sanFU1zezbW8d++qa2FfbSGsnrcRRPiMpNoqE2CgSY6OJj4kiJS6a5PhoUuKjyUiMJTPJv/RLjqVfchzZKXEMSk/QldZ9gAqASIiIjfaRndK71x20tDr27G+kvLqBPfsbqWlopqahmf2Bn7WNzexvaDlYlPY3tLC/oZmy6no2lDezZ38j1fXNR7xutM8o7JfEsJxkhuUkM7R/CsNykinql6QxmXqRCoCIdCjKZ2SnxB1X4WlsbmVvbSO7axoor26grLqBTbv3s76shtU7q3n1k50HjzJ8BgVZ/sIwNCeZIdn+n0NzkkmK08dVd9NvVER6VGy0j/6p8fTv4EK5+qYWSir2s3ZXDevLali3q5q1u6p5Y3UZzW3anwZnJjCifwrD+qccLAxDspNIiY/prV0JOyoAIuKp+JgoRg5IZeSA1EPWN7W0srmilvVlNawvq2b1zmrW7Kxm/pryQwpD/9S4gwVh1MBURg1MZUT/FBJi1ZTUFRUAEemTYqJ8B5t/YMDB9U0trWzZU8uGshrWl9ewsXw/G8prmLV0GzUNmwEwg/zMRIb3T2FE/xSGD/D/LOqXRGy0Op8PUAEQkZASE+VjSLa/f+DcNutbWx2le+tYuaOK1TurWLerhjWBpqSWwBFDtM8YmpPMyAEpjByYyuiBqYwelEq/CB3UTwVARMKCz2fkZyWSn5XIeWM/PWJoaG5hY/l+1u76tBlp0aY9zP5o+8FtclLiGD0olTGDUgOn46YyOCMx7C+gUwEQkbAWFx11sG9gRpv1+2obWbmjipXbA8uOKt5et/vg0UJS7KfPGz0oPPsWPBsLyMy+C3wbaAH+5pz7z66eo7GARKQn1Te1sHZXNat2fFoUVu2opqbBfy2DGRRlJTFyYApDs5Mp7JdEQVYSuekJZCXH9tmL2/rUWEBmdhYwA5jgnGswsxwvcoiItBUfE8X4vHTG56UfXHd438LqHdV8sr2KV1bsPOIq6fTEGJLjog8OsWEQuMraf+W1f4Fj+d79qysmcvLQfse1f4fzqgnom8D/OOcaAJxzZR7lEBHpVEd9C43NrWzdW0vJ7v3srKpnd7X/Yrf9Dc20BIbZwEF0lBHt8+Ez/4V1ZsaxDJ/UrweuAveqAAwHTjOznwL1wPedcx+0t6GZ3QDcAJCfn997CUVEOhEb/enZSKGqxwqAmb1O25N3P3VH4H0zgWnAicBzZlbs2umQcM49AjwC/j6AnsorIhJpeqwAOOfO6egxM/smMCvwgb/IzFqBfkB5T+UREZFDedVlPRs4C8DMhgOxwG6PsoiIRCSv+gCeAJ4wsxVAI3B1e80/IiLSczwpAM65RuDLXry3iIj49c2rFkREpMepAIiIRCgVABGRCOXZWEDHwszKgc3H+PR+RN6ZRtrnyKB9jgzHs88Fzrnsw1eGVAE4Hma2uL3BkMKZ9jkyaJ8jQ0/ss5qAREQilAqAiEiEiqQC8IjXATygfY4M2ufI0O37HDF9ACIicqhIOgIQEZE2VABERCJU2BUAMzvPzNaY2Xozu62dx+PM7M+BxxeaWaEHMbtVEPt8i5mtNLPlZvYPMyvwImd36mqf22x3qZk5MwvpUwaD2V8zuzzwd/7EzJ7p7YzdLYh/1/lm9k8z+zDwb/sCL3J2JzN7wszKAgNltve4mdmvAr+T5WY26bje0DkXNgsQBWwAivEPMb0MGH3YNt8Cfhu4fQXwZ69z98I+nwUkBm5/MxL2ObBdCvAW8D4wxevcPfw3HgZ8CGQE7ud4nbsX9vkR4JuB26OBEq9zd8N+nw5MAlZ08PgFwN8Bwz+h1sLjeb9wOwI4CVjvnNvo/COOPot/8vm2ZgB/CNx+AfiM2bHM0NlndLnPzrl/OudqA3ffB/J6OWN3C+bvDHAXcC/+aUdDWTD7+3Xg1865vRAW82wHs88OSA3cTgO292K+HuGcewvY08kmM4CnnN/7QLqZDTzW9wu3ApALbG1zvzSwrt1tnHPNQCWQ1SvpekYw+9zW9fi/QYSyLvc5cGg82Dn3t94M1kOC+RsPB4ab2btm9r6Znddr6XpGMPt8J/BlMysF5gLf7Z1onjra/++d8mpCGPGAmX0ZmAKc4XWWnmRmPuAXwDUeR+lN0fibgc7Ef4T3lpmNc87t8zJUD/sS8KRz7n4zmw780czGOudavQ4WKsLtCGAbMLjN/bzAuna3MbNo/IeOFb2SrmcEs8+Y2TnAHcDFzrmGXsrWU7ra5xRgLDDfzErwt5XOCeGO4GD+xqXAHOdck3NuE7AWf0EIVcHs8/XAcwDOuQVAPP4B08JZUP/fgxVuBeADYJiZFZlZLP5O3jmHbTMHuDpw+zLgDRfoXQlRXe6zmZ0A/A7/h3+otw1DF/vsnKt0zvVzzhU65wrx93tc7Jxb7E3c4xbMv+vZ+L/9Y2b98DcJbezFjN0tmH3eAnwGwMxG4S8A5b2asvfNAb4aOBtoGlDpnNtxrC8WVk1AzrlmM/sO8Cr+swiecM59YmY/ARY75+YAj+M/VFyPv7PlCu8SH78g9/nnQDLwfKC/e4tz7mLPQh+nIPc5bAS5v68C55rZSqAF+A/nXMge2Qa5z/8OPGpmN+PvEL4mxL/MYWYz8RfyfoG+jf8CYgCcc7/F39dxAbAeqAWuPa73C/Hfl4iIHKNwawISEZEgqQCIiEQoFQARkQilAiAiEqFUAEREIpQKgPQ5ZtZiZh+Z2Qoze97MEj3IcKaZndzBY3ea2fcPW1cSOP++V5nZ/GO9wM3MnjSzy7o7k4QOFQDpi+qccxOdc2OBRuAbwTwpcGV3dzkTaLcAiIQLFQDp694GhppZUmCs9EWB8d9nAJjZNWY2x8zeAP5hZslm9nsz+zgwXvqlge3ONbMFZrY0cFSRHFhfYmb/HVj/sZmNNP8cEd8Abg4ciZwWbFgzKzSzVWb2aGBc/tfMLCHw2Hwze8DMFge2OdHMZpnZOjO7u81rzDazJYHn3xBYFxX4xr4ikPPmw97XF3j87sC2PzezDwK/gxsD25iZPWz+MfZfB3KO/c8i4SCsrgSW8BL4Rn8+8Ar+cYzecM5dZ2bpwKLAhxj4x08f75zbY2b34r88flzgNTICTTM/BM5xzu03s1uBW4CfBJ6/2zk3ycy+BXzfOfc1M/stUOOcu+8Yog8DvuSc+7qZPQdcCjwdeKzROTfFzG4CXgIm478ifYOZPRC4eve6wL4kAB+Y2YtAIZAbOCoi8Ds4IBr4E/4x5H8aKBqVzrkTzSwOeNfMXgNOAEbgHzu/P7ASeOIY9k/ChAqA9EUJZvZR4Pbb+IfveA+4uE3bezyQH7g9zzl3YAz1c2gzvIdzbq+ZXYT/Q+/dwFAYscCCNu83K/BzCXBJEPk6unz+wPpNzrkD+Zfg//A+4MAwFR8DnxwYx8XMNuIf5KsC+J6ZfSGw3WD8BWUNUGxmDwF/A15r85q/A55zzv00cP9cYHyb9v20wGucDsx0zrUA2wNHTRLBVACkL6pzzk1su8L8n9yXOufWHLZ+KrC/i9cz/EXiSx08fmB01BaC+z9RARw+CUcKsC/ws+1oqy1AQjvv1XrYdq1AtJmdib+ITXfO1ZrZfCA+UMgmAJ/F3zx1OXBd4LnvAWeZ2f3OuXr8+/td59yrbQNaGEyZKN1LfQASKl4FvhsoBAdGOG3PPODbB+6YWQb+0UBPMbOhgXVJZja8i/erxv9h3p638B+NpARe7xJgWeCb9fFKA/YGPvxH4h/K+sAInz7n3Iv4m7PazgX7OP5Bwp4LNJu9CnzTzGICzx1uZkmB3P8a6CMYiH+qUIlgKgASKu7CPyricjP7JHC/PXcDGYHO0mXAWc65cvyTw8w0s+X4m39GdvF+fwW+0F4nsHNuOfAw8E6gqeobwNeObbeO8Ar+I4FVwP/gL17gn/VpfuD9ngZuPyzTL/DPCfxH4DH87ftLzT+5+O/wH9n8BVgXeOwpDm0Gkwik0UBFRCKUjgBERCKUCoCISIRSARARiVAqACIiEUoFQEQkQqkAiIhEKBUAEZEI9f800OBvwrCUvwAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sequential_perturbation = benchmark.perturbation.SequentialPerturbation(\n", " explainer.model, explainer.masker, sort_order, perturbation\n", ")\n", "xs, ys, auc = sequential_perturbation.model_score(shap_values, data[\"text\"][0:20])\n", "sequential_perturbation.plot(xs, ys, auc)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.9" } }, "nbformat": 4, "nbformat_minor": 4 }